"""
Build Paper: Generate all figures and compile .docx
=====================================================
Generates:
  - 7 figures in output/figures/
  - paper/paper.docx with full text, tables, and figures

Requires: matplotlib, python-docx, pandas, numpy
"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path
from docx import Document
from docx.shared import Inches, Pt, Cm, RGBColor
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.table import WD_TABLE_ALIGNMENT
import re

# Filesystem layout. PROJECT_DIR is a hard-coded absolute path, so the script
# runs unmodified only on a machine where this directory exists.
# NOTE(review): "/mnt/c/..." looks like a WSL mount of a Windows drive — confirm
# before running elsewhere.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
JAPAN_DIR = PROJECT_DIR / "japanification"
# CSV inputs read by the fig*_ functions.
TABLE_DIR = JAPAN_DIR / "output" / "tables"
# PNG outputs written by the fig*_ functions and embedded by build_docx().
FIG_DIR = JAPAN_DIR / "output" / "figures"
# Markdown source of the paper and destination of paper.docx.
PAPER_DIR = JAPAN_DIR / "paper"

# Import-time side effect: the figure directory is created as soon as this
# module is loaded, not only when main() runs.
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Style
# Global matplotlib defaults applied to every figure created after this point.
# Figures are created at 150 dpi but saved at 300 dpi with a tight bounding box.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 11,
    'axes.titlesize': 13,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.dpi': 150,
    'savefig.bbox': 'tight',
    'savefig.dpi': 300,
})

# Named hex colours shared by the figure functions so that all plots draw from
# the same palette.
COLORS = {
    'primary': '#2c3e50',
    'accent1': '#e74c3c',
    'accent2': '#3498db',
    'accent3': '#2ecc71',
    'accent4': '#f39c12',
    'accent5': '#9b59b6',
    'gray': '#95a5a6',
    'light_gray': '#ecf0f1',
}


# =====================================================================
# FIGURE 1: Age-Group Japanification Profile
# =====================================================================
def fig1_age_profile():
    """Render Figure 1: age-group coefficients as bars with +/-1.96*SE whiskers.

    Reads ``phase3_age_coefficients.csv`` from TABLE_DIR (columns used:
    ``age_group``, ``coefficient``, ``std_error``, ``p_value``) and writes
    ``fig1_age_profile.png`` to FIG_DIR. Each bar is coloured by its p-value
    band: below 0.05, below 0.10, or neither.
    """
    print("  Figure 1: Age-group profile")
    from matplotlib.patches import Patch

    ages = pd.read_csv(TABLE_DIR / "phase3_age_coefficients.csv")
    positions = range(len(ages))

    def shade(p_value):
        # Strictest band first; a value failing both comparisons (NaN
        # included) falls through to grey.
        if p_value < 0.05:
            return COLORS['accent1']
        if p_value < 0.1:
            return COLORS['accent2']
        return COLORS['gray']

    fig, ax = plt.subplots(figsize=(10, 5.5))
    ax.bar(
        positions,
        ages['coefficient'],
        yerr=1.96 * ages['std_error'],
        color=[shade(p) for p in ages['p_value']],
        edgecolor='white',
        linewidth=0.5,
        capsize=3,
        error_kw={'linewidth': 1, 'color': '#555'},
    )
    ax.axhline(y=0, color='black', linewidth=0.8)

    # Axis furniture
    ax.set_title('Figure 1: Age-Group Japanification Profile')
    ax.set_xlabel('Age Group')
    ax.set_ylabel('Coefficient on Japanification Index')
    ax.set_xticks(positions)
    ax.set_xticklabels(ages['age_group'], rotation=45, ha='right')

    # The bars carry no labels of their own, so the legend is built from
    # proxy patches, one per significance band.
    band_labels = (
        ('accent1', 'p < 0.05'),
        ('accent2', 'p < 0.10'),
        ('gray', 'Not significant'),
    )
    ax.legend(
        handles=[Patch(facecolor=COLORS[key], label=text)
                 for key, text in band_labels],
        loc='upper left',
    )

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig1_age_profile.png")
    plt.close()


# =====================================================================
# FIGURE 2: Rolling Window Z₁ Coefficient (Composite Index)
# =====================================================================
def fig2_rolling_windows():
    """Render Figure 2: the Z1 coefficient across rolling estimation windows.

    Reads ``phase4_rolling_windows.csv`` from TABLE_DIR (columns used:
    ``window_start``, ``window_end``, ``Z_1_coef``, ``Z_1_pval``) and writes
    ``fig2_rolling_windows.png`` to FIG_DIR.

    The shaded band is an approximate 95% interval. Only the coefficient and
    its p-value are read here, so the standard error is recovered by inverting
    the two-sided normal test: ``SE = |coef| / z`` with
    ``z = Phi^-1(1 - p/2)``. The previous expression
    (``coef / 1.645 * p``) is not a standard error: it shrinks to zero as the
    p-value shrinks and is about ten times too small at p = 0.10.

    NOTE(review): if the CSV also carries a standard-error column, plotting
    that directly would be preferable to this back-of-envelope inversion.
    """
    # Standard library only; imported locally so the block is self-contained.
    from statistics import NormalDist

    print("  Figure 2: Rolling windows")
    df = pd.read_csv(TABLE_DIR / "phase4_rolling_windows.csv")

    fig, ax = plt.subplots(figsize=(10, 5.5))

    mid_year = (df['window_start'] + df['window_end']) / 2
    coef = df['Z_1_coef']
    pval = df['Z_1_pval']

    ax.plot(mid_year, coef, 'o-', color=COLORS['primary'],
            linewidth=2, markersize=5, label='Z₁ coefficient')

    std_normal = NormalDist()

    def z_from_p(p):
        """Return the |z| statistic implied by a two-sided p-value."""
        if pd.isna(p):
            return np.nan
        # Clamp away from 0 and 1: p reported as 0.000 would give an infinite
        # z, and p = 1 gives z = 0 (division by zero in the SE below).
        p = min(max(p, 0.001), 0.999)
        return std_normal.inv_cdf(1 - p / 2)

    approx_se = coef.abs() / pval.apply(z_from_p)
    half_width = 1.96 * approx_se
    ax.fill_between(mid_year, coef - half_width, coef + half_width,
                    alpha=0.15, color=COLORS['accent2'])

    # Mark significant points
    sig = pval < 0.1
    ax.scatter(mid_year[sig], df.loc[sig, 'Z_1_coef'],
               color=COLORS['accent1'], s=60, zorder=5, label='p < 0.10')

    ax.axhline(y=0, color='black', linewidth=0.8, linestyle='--')
    ax.axvline(x=2008, color=COLORS['accent4'], linewidth=1.5, linestyle=':',
               label='GFC (2008)')

    ax.set_xlabel('Window Midpoint')
    ax.set_ylabel('Z₁ Coefficient')
    ax.set_title('Figure 2: Rolling 15-Year Window Z₁ Coefficient')
    ax.legend(loc='upper right')

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig2_rolling_windows.png")
    plt.close()


# =====================================================================
# FIGURE 3: Component-by-Component Rolling Windows
# =====================================================================
def fig3_component_rolling():
    """Render Figure 3: rolling-window Z1 coefficient, one panel per component.

    Reads ``phase4b_component_rolling.csv`` from TABLE_DIR (columns used:
    ``component``, ``window_start``, ``window_end``, ``Z_1_coef``,
    ``Z_1_pval``) and writes ``fig3_component_rolling.png`` to FIG_DIR.
    Windows with p < 0.10 get a larger, black-edged marker.
    """
    print("  Figure 3: Component rolling windows")
    table = pd.read_csv(TABLE_DIR / "phase4b_component_rolling.csv")

    # (value in the 'component' column, panel title, line colour)
    panels = [
        ('growth', 'Growth', COLORS['accent1']),
        ('inflation', 'Inflation', COLORS['accent2']),
        ('rate', 'Interest Rate', COLORS['accent3']),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.5), sharey=False)

    for idx, (ax, (component, heading, colour)) in enumerate(zip(axes, panels)):
        rows = table[table['component'] == component]
        midpoint = (rows['window_start'] + rows['window_end']) / 2
        coefs = rows['Z_1_coef']

        ax.plot(midpoint, coefs, 'o-', color=colour, linewidth=2, markersize=4)

        # Positional boolean mask, so it applies to both the midpoint series
        # and the filtered frame regardless of their index labels.
        significant = (rows['Z_1_pval'] < 0.1).values
        ax.scatter(midpoint[significant], rows.loc[significant, 'Z_1_coef'],
                   color=colour, s=50, edgecolors='black', linewidth=0.8,
                   zorder=5)

        ax.axhline(y=0, color='black', linewidth=0.8, linestyle='--')
        ax.axvline(x=2008, color=COLORS['accent4'], linewidth=1, linestyle=':')
        ax.set_title(heading)
        ax.set_xlabel('Window Midpoint')
        # Only the left-most panel carries the y-axis label.
        if idx == 0:
            ax.set_ylabel('Z₁ Coefficient')

    fig.suptitle('Figure 3: Component-by-Component Rolling Window Z₁',
                 fontsize=13, y=1.02)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig3_component_rolling.png")
    plt.close()


# =====================================================================
# FIGURE 4: Country Projection Timelines
# =====================================================================
def fig4_projections():
    """Render Figure 4: projected index trajectories for ten focus countries.

    Reads ``phase5_country_timeline.csv`` from TABLE_DIR (columns used:
    ``iso3``, ``current_index``, ``proj_2030``, ``proj_2040``, ``proj_2050``)
    and writes ``fig4_projections.png`` to FIG_DIR. ``current_index`` is
    plotted at x = 2020. Countries absent from the table are skipped, as are
    individual missing values.
    """
    print("  Figure 4: Country projections")
    timeline = pd.read_csv(TABLE_DIR / "phase5_country_timeline.csv")

    # ISO3 code -> legend label, in plotting order.
    country_names = {
        'JPN': 'Japan', 'KOR': 'Korea', 'DEU': 'Germany', 'ITA': 'Italy',
        'USA': 'United States', 'CHN': 'China', 'GBR': 'United Kingdom',
        'IND': 'India', 'BRA': 'Brazil', 'THA': 'Thailand',
    }
    focus = ['JPN', 'KOR', 'DEU', 'ITA', 'USA', 'CHN', 'GBR', 'IND', 'BRA', 'THA']
    years = [2030, 2040, 2050]

    # x positions and the column that supplies the value at each of them.
    x_years = [2020, *years]
    value_columns = ['current_index', *(f'proj_{y}' for y in years)]

    fig, ax = plt.subplots(figsize=(11, 6))
    palette = plt.cm.tab10(np.linspace(0, 1, len(focus)))

    for line_colour, iso3 in zip(palette, focus):
        matches = timeline[timeline['iso3'] == iso3]
        if matches.empty:
            continue
        record = matches.iloc[0]

        # Keep only the (year, value) pairs that actually have a value.
        points = [(yr, record[col])
                  for yr, col in zip(x_years, value_columns)
                  if pd.notna(record[col])]
        if not points:
            continue

        xs, ys = zip(*points)
        ax.plot(xs, ys, 'o-', color=line_colour, linewidth=2,
                markersize=5, label=country_names.get(iso3, iso3))

    # Threshold line
    ax.axhline(y=0.226, color=COLORS['accent1'], linewidth=1.5, linestyle='--',
               label='Japan ~2000 threshold')

    ax.set_title('Figure 4: Projected Japanification Trajectories')
    ax.set_xlabel('Year')
    ax.set_ylabel('Japanification Index')
    ax.set_xlim(2018, 2052)
    # Legend sits outside the axes, to the right.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=9)

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig4_projections.png")
    plt.close()


# =====================================================================
# FIGURE 5: OADR Threshold
# =====================================================================
def fig5_oadr_threshold():
    """Render Figure 5: spline coefficients below and above each OADR knot.

    Reads ``phase4_oadr_thresholds.csv`` from TABLE_DIR (columns used:
    ``knot``, ``coef_below``, ``coef_above``, ``p_below``, ``p_above``) and
    writes ``fig5_oadr_threshold.png`` to FIG_DIR. ``knot`` is multiplied by
    100 for display as a percentage. Bars with p < 0.10 are annotated with
    significance stars (* / ** / *** for 0.10 / 0.05 / 0.01).
    """
    print("  Figure 5: OADR threshold")
    df = pd.read_csv(TABLE_DIR / "phase4_oadr_thresholds.csv")

    fig, ax = plt.subplots(figsize=(8, 5))

    x = df['knot'] * 100
    width = 2.0

    ax.bar(x - width/2, df['coef_below'], width,
           color=COLORS['accent2'], alpha=0.7, label='Below knot')
    ax.bar(x + width/2, df['coef_above'], width,
           color=COLORS['accent1'], alpha=0.7, label='Above knot')

    # Add significance stars
    for _, row in df.iterrows():
        for offset, coef, p in [(- width/2, row['coef_below'], row['p_below']),
                                 (width/2, row['coef_above'], row['p_above'])]:
            sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
            if sig:
                # Stars go just above positive bars and just below negative ones.
                y_pos = coef + 0.05 if coef >= 0 else coef - 0.15
                ax.text(row['knot']*100 + offset, y_pos, sig,
                        ha='center', fontsize=10, fontweight='bold')

    ax.set_xlabel('OADR Knot Point (%)')
    ax.set_ylabel('Coefficient on OADR')
    ax.set_title('Figure 5: OADR Spline Regression — Threshold Test')
    ax.set_xticks(x)
    # Format with 'g' rather than int(): int() truncates, so a knot whose
    # scaled value lands just under an integer (0.29 * 100 ==
    # 28.999999999999996) would be labelled "28%", and a fractional knot such
    # as 17.5 would lose its decimal.
    ax.set_xticklabels([f'{k:g}%' for k in x])
    ax.axhline(y=0, color='black', linewidth=0.8)
    ax.legend()

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig5_oadr_threshold.png")
    plt.close()


# =====================================================================
# FIGURE 6: Income Quartile Z₁ Coefficients
# =====================================================================
def fig6_income_quartiles():
    """Render Figure 6: Z1 coefficient for each income quartile.

    Reads ``phase4c_income_quartiles.csv`` from TABLE_DIR (columns used:
    ``Z_1_coef``, ``Z_1_pval``, ``median_gdp_pc``) and writes
    ``fig6_income_quartiles.png`` to FIG_DIR. Bars are coloured by p-value
    band (below 0.01, below 0.10, otherwise grey) and each is annotated with
    its p-value plus significance stars.
    """
    print("  Figure 6: Income quartiles")
    quartiles = pd.read_csv(TABLE_DIR / "phase4c_income_quartiles.csv")
    coefs = quartiles['Z_1_coef']
    pvals = quartiles['Z_1_pval']
    positions = range(len(quartiles))

    def band_colour(p):
        if p < 0.01:
            return COLORS['accent1']
        if p < 0.1:
            return COLORS['accent2']
        return COLORS['gray']

    def stars(p):
        if p < 0.01:
            return '***'
        if p < 0.05:
            return '**'
        if p < 0.1:
            return '*'
        return ''

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(positions, coefs, color=[band_colour(p) for p in pvals],
           edgecolor='white', linewidth=0.5)

    # Tick label: quartile number over the quartile's median GDP per capita.
    tick_text = [f'Q{rank}\n(${int(median):,})'
                 for rank, median in enumerate(quartiles['median_gdp_pc'], start=1)]
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_text)
    ax.set_ylabel('Z₁ Coefficient')
    ax.set_xlabel('Income Quartile (median GDP/cap PPP)')
    ax.set_title('Figure 6: Z₁ Coefficient by Income Quartile')
    ax.axhline(y=0, color='black', linewidth=0.8)

    # Add p-values: above the bar when the coefficient is non-negative,
    # below it otherwise.
    for pos, coef, p in zip(positions, coefs, pvals):
        if coef >= 0:
            y_pos = coef + 0.15
        else:
            y_pos = coef - 0.35
        ax.text(pos, y_pos, f'p={p:.3f}{stars(p)}', ha='center', fontsize=9)

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig6_income_quartiles.png")
    plt.close()


# =====================================================================
# FIGURE 7: GE Trajectory — Global Japanification Timeline
# =====================================================================
def fig7_ge_trajectory():
    """Render Figure 7: share of countries above threshold and the mean index.

    Reads ``phase5_ge_trajectory.csv`` from TABLE_DIR (columns used: ``year``,
    ``share_above``, ``mean_japan_index``) and writes
    ``fig7_ge_trajectory.png`` to FIG_DIR. The share is drawn as bars on the
    left axis (scaled by 100 to a percentage); the mean index is drawn as a
    line on a twin right axis.
    """
    print("  Figure 7: GE trajectory")
    trajectory = pd.read_csv(TABLE_DIR / "phase5_ge_trajectory.csv")
    year = trajectory['year']
    share_colour = COLORS['accent1']
    index_colour = COLORS['primary']

    fig, share_ax = plt.subplots(figsize=(9, 5))

    # Left axis: share of countries above the threshold.
    share_ax.bar(year, trajectory['share_above'] * 100,
                 color=share_colour, alpha=0.6, width=4,
                 label='Share above threshold (%)')
    share_ax.set_ylabel('Share of Countries Above Threshold (%)',
                        color=share_colour)
    share_ax.tick_params(axis='y', labelcolor=share_colour)
    share_ax.set_ylim(0, 70)

    # Right axis: mean index, sharing the same x-axis.
    index_ax = share_ax.twinx()
    index_ax.plot(year, trajectory['mean_japan_index'], 's-',
                  color=index_colour, linewidth=2, markersize=6,
                  label='Mean Japanification Index')
    index_ax.set_ylabel('Mean Japanification Index', color=index_colour)
    index_ax.tick_params(axis='y', labelcolor=index_colour)

    share_ax.set_xlabel('Year')
    share_ax.set_title('Figure 7: Global Japanification Trajectory (24 Focus Countries)')

    # Combined legend: each axis only knows its own artists, so collect both
    # sets of handles and draw a single legend on the left axis.
    share_handles, share_labels = share_ax.get_legend_handles_labels()
    index_handles, index_labels = index_ax.get_legend_handles_labels()
    share_ax.legend(share_handles + index_handles,
                    share_labels + index_labels, loc='upper left')

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig7_ge_trajectory.png")
    plt.close()


# =====================================================================
# BUILD .DOCX
# =====================================================================
def build_docx():
    """Convert the paper's markdown source into ``paper.docx`` in PAPER_DIR.

    The markdown file is walked line by line:

    * ``# `` .. ``###### `` lines become headings (a level-1 line whose text
      contains both "Japanification" and "Who" is rendered as the centred
      paper title instead). After every heading of level 2 or deeper, the
      hard-coded tables and the figure PNGs registered for that section are
      inserted.
    * ``$$`` blocks and single-line ``$$...$$`` become centred italic
      paragraphs.
    * Markdown table rows (lines starting with ``|``) and ``---`` separators
      are dropped; tables come from ``insert_tables_for`` instead.
    * Everything else is joined into paragraphs with inline markdown stripped.

    Raises:
        FileNotFoundError: if none of the candidate markdown files exists in
            PAPER_DIR.
    """
    print("\nBuilding .docx ...")
    doc = Document()

    # --- Styles ---
    style = doc.styles['Normal']
    font = style.font
    font.name = 'Times New Roman'
    font.size = Pt(11)

    pf = style.paragraph_format
    pf.space_after = Pt(6)
    pf.space_before = Pt(0)

    # Read the paper markdown
    # Try dated filenames (newest first), fall back to paper.md
    candidates = [
        "japanification_paper_20260221_v2.md",
        "japanification_paper_20260221.md",
        "japanification_paper_20260220.md",
        "paper.md",
    ]
    for name in candidates:
        paper_path = PAPER_DIR / name
        if paper_path.exists():
            break
    else:
        # Report every name that was tried rather than only the last fallback.
        raise FileNotFoundError(
            f"No paper markdown found in {PAPER_DIR}; tried: "
            f"{', '.join(candidates)}")
    paper_text = paper_path.read_text(encoding='utf-8')
    lines = paper_text.split('\n')

    # Heading substring -> figure files to embed right after that heading.
    figure_after = {
        '5.5 Age-Group Profile': ['fig1_age_profile.png'],
        '6. The Post-GFC Structural Break': ['fig2_rolling_windows.png'],
        '6.2 Which Component Broke?': ['fig3_component_rolling.png'],
        '7.5 Income Quartile Analysis': ['fig6_income_quartiles.png'],
        '8.2 Country Timelines': ['fig4_projections.png'],
        '8.3 Global Trajectory': ['fig7_ge_trajectory.png'],
        '5.3 OADR Threshold': ['fig5_oadr_threshold.png'],
    }

    def add_docx_table(doc, title, headers, data, note=None):
        """Append a centred, titled table, optionally followed by a note.

        ``headers`` is a list of column titles, ``data`` a list of rows (each
        a list of strings with one entry per header) and ``note`` an optional
        footnote rendered in small italics below the table.
        """
        p = doc.add_paragraph()
        p.alignment = WD_ALIGN_PARAGRAPH.CENTER
        run = p.add_run(title)
        run.bold = True
        run.font.size = Pt(10)

        table = doc.add_table(rows=1 + len(data), cols=len(headers))
        table.style = 'Light Shading'
        table.alignment = WD_TABLE_ALIGNMENT.CENTER

        for j, h in enumerate(headers):
            cell = table.rows[0].cells[j]
            cell.text = h
            for paragraph in cell.paragraphs:
                for run in paragraph.runs:
                    run.bold = True
                    run.font.size = Pt(9)

        for r, row_data in enumerate(data):
            for c, val in enumerate(row_data):
                cell = table.rows[r+1].cells[c]
                cell.text = val
                for paragraph in cell.paragraphs:
                    for run in paragraph.runs:
                        run.font.size = Pt(9)

        if note:
            p = doc.add_paragraph(note)
            p.paragraph_format.space_before = Pt(2)
            for run in p.runs:
                run.font.size = Pt(8)
                run.italic = True

    def insert_tables_for(doc, heading_text):
        """Insert the table registered for ``heading_text``, if any.

        The table contents are literal values kept in this function; they are
        not read from the CSV files in TABLE_DIR.
        """
        if '5.1 Baseline' in heading_text:
            add_docx_table(doc,
                'Table 1: Baseline Japanification Regressions',
                ['Variable', 'Model 1 (2c)', 'Model 1b (3c)', 'Model 2 (KAOPEN)'],
                [
                    ['Z1', '0.965', '0.113', '0.889'],
                    ['', '(0.601)', '(0.373)', '(0.626)'],
                    ['Z2', '-0.160*', '-0.032', '-0.137'],
                    ['', '(0.090)', '(0.055)', '(0.096)'],
                    ['Z3', '0.007*', '0.002', '0.006'],
                    ['', '(0.004)', '(0.002)', '(0.004)'],
                    ['fiscal_bal_gdp', '-0.014***', '-0.007***', '-0.014***'],
                    ['kaopen', '0.033***', '0.021***', '0.067***'],
                    ['R-sq', '0.071', '0.158', '0.075'],
                    ['N obs', '3,520', '2,522', '3,520'],
                    ['N countries', '136', '108', '136'],
                ],
                'Standard errors in parentheses. *** p<0.01, ** p<0.05, * p<0.1')

        elif '5.2 Component' in heading_text:
            add_docx_table(doc,
                'Table 2: Component Regressions',
                ['Variable', 'Growth', 'Inflation', 'Rate'],
                [
                    ['Z1', '-11.504', '-11.480', '-23.451'],
                    ['Z2', '1.818*', '3.288', '2.886'],
                    ['Z3', '-0.079*', '-0.163', '-0.102'],
                    ['R-sq', '0.049', '-0.007', '0.166'],
                    ['N', '3,520', '3,520', '2,522'],
                ])

        elif '6.1 Rolling' in heading_text:
            add_docx_table(doc,
                'Table 3: Pre/Post GFC Structural Break',
                ['Variable', 'Pre-GFC (1990-2007)', 'Post-GFC (2009-2024)'],
                [
                    ['Z1', '2.730*** (0.998)', '-0.927 (0.587)'],
                    ['Z2', '-0.452*** (0.152)', '0.109 (0.086)'],
                    ['Z3', '0.020*** (0.006)', '-0.004 (0.003)'],
                    ['R-sq', '0.057', '0.169'],
                    ['N', '1,895', '1,491'],
                ])

        elif '7.5 Income Quartile' in heading_text:
            add_docx_table(doc,
                'Table 4: Z1 Coefficient by Income Quartile',
                ['Quartile', 'Median GDP/cap', 'Z1', 'p-value', 'N'],
                [
                    ['Q1', '$2,458', '0.90', '0.527', '871'],
                    ['Q2', '$7,867', '5.41', '<0.001', '821'],
                    ['Q3', '$18,569', '-0.87', '0.485', '800'],
                    ['Q4', '$49,233', '1.13', '0.131', '974'],
                ])

        elif '7.3 The Demographic Dividend' in heading_text:
            add_docx_table(doc,
                'Table 5: Working-Age Share Decomposition (Low-Income)',
                ['Variable', 'Without WAS', 'With WAS'],
                [
                    ['Z1', '3.209*** (1.189)', '0.236 (1.803)'],
                    ['working_age_share', '--', '6.794** (3.121)'],
                    ['R-sq', '0.002', '0.008'],
                ])

        elif '5.3 OADR Threshold' in heading_text:
            add_docx_table(doc,
                'Table 6: OADR Spline Regression Results',
                ['Knot', 'Below', 'p', 'Above', 'p', 'R-sq'],
                [
                    ['15%', '-0.05', '0.91', '1.56***', '<0.001', '0.073'],
                    ['20%', '0.42', '0.13', '1.71***', '0.001', '0.072'],
                    ['25%', '0.71***', '0.001', '1.51*', '0.070', '0.071'],
                    ['30%', '0.80***', '<0.001', '1.03', '0.491', '0.070'],
                ])

    def insert_figures_for(doc, heading_text):
        """Embed every figure whose ``figure_after`` key occurs in the heading."""
        for key, figs in figure_after.items():
            if key in heading_text:
                for fig_name in figs:
                    fig_path = FIG_DIR / fig_name
                    if fig_path.exists():
                        p = doc.add_paragraph()
                        p.alignment = WD_ALIGN_PARAGRAPH.CENTER
                        run = p.add_run()
                        run.add_picture(str(fig_path), width=Inches(5.5))
                    else:
                        # Keep building, but do not drop a figure silently.
                        print(f"  WARNING: figure not found, skipped: {fig_path}")

    def clean_md(text):
        """Strip inline markdown (math delimiters, emphasis, @citations)."""
        text = re.sub(r'\$\$[^$]*\$\$', '', text)
        # Inline math uses Pandoc's delimiter rules: the opening $ is followed
        # by a non-space, the closing $ is preceded by a non-space and not
        # followed by a digit. Without these guards, two currency amounts in
        # one paragraph ("$2,458 ... $7,867") were treated as a math span and
        # lost their dollar signs.
        text = re.sub(r'\$(?=\S)([^$]+?)(?<=\S)\$(?!\d)', r'\1', text)
        text = re.sub(r'\*\*([^*]+)\*\*', r'\1', text)
        text = re.sub(r'\*([^*]+)\*', r'\1', text)
        text = re.sub(r'@\w+', '', text)
        return text.strip()

    def add_math_paragraph(content):
        """Append ``content`` as a centred, italic, 10pt paragraph."""
        p = doc.add_paragraph()
        p.alignment = WD_ALIGN_PARAGRAPH.CENTER
        run = p.add_run(content)
        run.italic = True
        run.font.size = Pt(10)

    # --- Process line by line ---
    i = 0
    in_math = False
    math_buf = []

    while i < len(lines):
        line = lines[i]
        i += 1
        stripped = line.strip()

        # Multi-line math blocks. Handled before the separator/table skips so
        # that a math line starting with '|' (e.g. an absolute value) or equal
        # to '---' is kept instead of being dropped.
        if stripped == '$$':
            if in_math:
                # End of math block
                add_math_paragraph(' '.join(math_buf))
                math_buf = []
                in_math = False
            else:
                in_math = True
            continue

        if in_math:
            math_buf.append(stripped)
            continue

        # Skip frontmatter separators
        if stripped == '---':
            continue

        # Skip markdown tables (we insert our own formatted ones)
        if stripped.startswith('|'):
            continue

        # Display math on its own line ($$...$$)
        if stripped.startswith('$$') and stripped.endswith('$$'):
            add_math_paragraph(stripped[2:-2].strip())
            continue

        # Headings, levels 1-6. Levels 4+ used to fall through to the
        # paragraph branch and appear in the document with literal '#' marks.
        heading = re.match(r'(#{1,6}) (.*)', line)
        if heading:
            level = len(heading.group(1))
            title = heading.group(2).strip()
            if level == 1:
                if 'Japanification' in title and "Who" in title:
                    # Paper title block
                    p = doc.add_paragraph()
                    p.alignment = WD_ALIGN_PARAGRAPH.CENTER
                    run = p.add_run(title)
                    run.bold = True
                    run.font.size = Pt(18)

                    p = doc.add_paragraph()
                    p.alignment = WD_ALIGN_PARAGRAPH.CENTER
                    run = p.add_run('Working Paper \u2014 February 2026')
                    run.font.size = Pt(12)
                    run.italic = True
                else:
                    doc.add_heading(title, level=1)
            else:
                doc.add_heading(title, level=level)
                # Insert tables and figures after relevant section headings
                insert_tables_for(doc, title)
                insert_figures_for(doc, title)
            continue

        # Empty line
        if not stripped:
            continue

        # Paragraph text — collect continuation lines
        para_lines = [line]
        while i < len(lines):
            next_line = lines[i]
            next_stripped = next_line.strip()
            # Stop at headings, empty lines, tables, math blocks
            if (not next_stripped or next_line.startswith('#') or
                    next_stripped.startswith('|') or next_stripped == '$$' or
                    (next_stripped.startswith('$$') and next_stripped.endswith('$$'))):
                break
            para_lines.append(next_line)
            i += 1

        text = ' '.join(l.strip() for l in para_lines)
        text = clean_md(text)

        if text:
            doc.add_paragraph(text)

    # An opening '$$' with no closing line would otherwise discard everything
    # buffered after it; emit it so the content is at least visible.
    if in_math and math_buf:
        print("  WARNING: unterminated $$ block; flushing buffered lines")
        add_math_paragraph(' '.join(math_buf))

    # Save
    docx_path = PAPER_DIR / "paper.docx"
    doc.save(str(docx_path))
    print(f"  Saved: {docx_path}")


def main():
    """Run the full pipeline: render all seven figures, then build the .docx."""
    banner = "=" * 70

    print(banner)
    print("Building Figures and Paper .docx")
    print(banner)

    print("\nGenerating figures ...")
    figure_builders = (
        fig1_age_profile,
        fig2_rolling_windows,
        fig3_component_rolling,
        fig4_projections,
        fig5_oadr_threshold,
        fig6_income_quartiles,
        fig7_ge_trajectory,
    )
    for build_figure in figure_builders:
        build_figure()

    print(f"\n  All figures saved to {FIG_DIR}")

    # Figures must exist on disk before the document embeds them.
    build_docx()

    print(f"\n{banner}")
    print("Done.")
    print(banner)


if __name__ == "__main__":
    main()
